from sqlalchemy.orm import Session
from sqlalchemy import and_, func
from datetime import datetime, date
from typing import List, Optional, Dict, Any
import logging

import database, schemas
from cache import cached, CacheManager

# Logging setup: configure the root logger at INFO level and create the
# module-level logger used by every function below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 题目相关操作
@cached(ttl=600, key_prefix="questions")
def get_questions(db: Session) -> List[database.Question]:
    """Fetch every row of the question table.

    The function is wrapped by ``cached`` with ``ttl=600`` and the
    ``"questions"`` key prefix.

    Args:
        db: SQLAlchemy session used to run the query.

    Returns:
        All ``database.Question`` rows.

    Raises:
        Exception: any failure is logged and re-raised unchanged.
    """
    try:
        all_rows = db.query(database.Question).all()
        total = len(all_rows)
        logger.info(f"获取到 {total} 道题目")
        return all_rows
    except Exception as exc:
        logger.error(f"获取题目失败: {str(exc)}")
        raise

@cached(ttl=600, key_prefix="question")
def get_question(db: Session, question_id: int) -> Optional[database.Question]:
    """Look up a single question by primary key.

    The function is wrapped by ``cached`` with ``ttl=600`` and the
    ``"question"`` key prefix.

    Args:
        db: SQLAlchemy session.
        question_id: Primary key of the question; must be positive.

    Returns:
        The matching ``database.Question``, or ``None`` when no row exists
        (a warning is logged in that case).

    Raises:
        ValueError: if ``question_id`` is not greater than zero.
        Exception: any other failure is logged and re-raised unchanged.
    """
    try:
        if question_id <= 0:
            raise ValueError("题目ID必须大于0")

        by_id = db.query(database.Question).filter(database.Question.id == question_id)
        found = by_id.first()
        if not found:
            logger.warning(f"未找到ID为 {question_id} 的题目")
        return found
    except Exception as exc:
        logger.error(f"获取题目失败 (ID: {question_id}): {str(exc)}")
        raise

def create_question(db: Session, question: schemas.QuestionCreate) -> database.Question:
    """Create a question and attach it to the default question bank.

    The default bank (named '默认题库') is looked up by name and created on the
    fly when it does not exist yet.  After the insert, the bank's question
    count is refreshed and the question cache is cleared.

    Args:
        db: SQLAlchemy session.
        question: Incoming payload; ``question_content`` and ``answer`` are
            mandatory, every other text field is optional.

    Returns:
        The persisted ``database.Question`` (refreshed after commit).

    Raises:
        ValueError: if the content is blank or the answer is not one of
            A/B/C/D (case-insensitive, surrounding whitespace ignored).
        Exception: any database error; the session is rolled back first.
    """
    def _clean(value):
        # Optional text fields are stored as a stripped string, never None.
        return value.strip() if value else ""

    try:
        # Validate the input.
        content = _clean(question.question_content)
        if not content:
            raise ValueError("题目内容不能为空")
        # Normalise *before* validating: the stored value is upper-cased, so
        # a lower-case answer such as "a" must be accepted, not rejected.
        answer = _clean(question.answer).upper()
        if answer not in ('A', 'B', 'C', 'D'):
            raise ValueError("答案必须是A、B、C、D中的一个")

        # One timestamp so created_time and updated_time agree exactly.
        now = datetime.now()

        # Fetch the default bank, creating it when it does not exist.
        default_bank = db.query(database.QuestionBank).filter(database.QuestionBank.name == '默认题库').first()
        if not default_bank:
            default_bank = database.QuestionBank(
                name='默认题库',
                description='系统自动创建的默认题库',
                created_by='system',
                created_time=now,
                updated_time=now
            )
            db.add(default_bank)
            db.commit()
            db.refresh(default_bank)

        db_question = database.Question(
            question_content=content,
            option_a=_clean(question.option_a),
            option_b=_clean(question.option_b),
            option_c=_clean(question.option_c),
            option_d=_clean(question.option_d),
            answer=answer,
            knowledge_point=_clean(question.knowledge_point),
            explanation=_clean(question.explanation),
            wrong_analysis=_clean(question.wrong_analysis),
            difficulty_level=getattr(question, 'difficulty_level', None),
            topic_category=getattr(question, 'topic_category', None),
            question_bank_id=default_bank.id,
            created_time=now,
            updated_time=now
        )

        db.add(db_question)
        db.commit()
        db.refresh(db_question)

        # Keep the bank's stored question count in step with the new row.
        update_question_bank_count(db, default_bank.id)

        # Cached question lists are stale now.
        CacheManager.clear_questions_cache()

        logger.info(f"成功创建题目 ID: {db_question.id}")
        return db_question
    except Exception as e:
        db.rollback()
        logger.error(f"创建题目失败: {str(e)}")
        raise

def clear_all_questions(db: Session):
    """Delete every question row and clear the question cache.

    The number of rows present before the delete is logged on success.

    Args:
        db: SQLAlchemy session.

    Raises:
        Exception: any failure; the session is rolled back, the error is
            logged, and the exception is re-raised.
    """
    try:
        question_query = db.query(database.Question)
        existing = question_query.count()
        question_query.delete()
        db.commit()

        # Cached question lists are stale now.
        CacheManager.clear_questions_cache()

        logger.info(f"成功清空 {existing} 道题目")
    except Exception as exc:
        db.rollback()
        logger.error(f"清空题目失败: {str(exc)}")
        raise

def update_question(db: Session, question_id: int, question: schemas.QuestionCreate) -> Optional[database.Question]:
    """Overwrite an existing question with the supplied payload.

    Args:
        db: SQLAlchemy session.
        question_id: Primary key of the question to update; must be positive.
        question: New field values; ``question_content`` and ``answer`` are
            mandatory, every other text field is optional.

    Returns:
        The refreshed ``database.Question``, or ``None`` when no question
        with that id exists.

    Raises:
        ValueError: if the id is not positive, the content is blank, or the
            answer is not one of A/B/C/D (case-insensitive, surrounding
            whitespace ignored).
        Exception: any database error; the session is rolled back first.
    """
    def _clean(value):
        # Optional text fields are stored as a stripped string, never None.
        return value.strip() if value else ""

    try:
        if question_id <= 0:
            raise ValueError("题目ID必须大于0")

        # Validate the input.
        content = _clean(question.question_content)
        if not content:
            raise ValueError("题目内容不能为空")
        # Normalise *before* validating: the stored value is upper-cased, so
        # a lower-case answer such as "a" must be accepted, not rejected.
        answer = _clean(question.answer).upper()
        if answer not in ('A', 'B', 'C', 'D'):
            raise ValueError("答案必须是A、B、C、D中的一个")

        db_question = db.query(database.Question).filter(database.Question.id == question_id).first()
        if not db_question:
            logger.warning(f"未找到ID为 {question_id} 的题目")
            return None

        db_question.question_content = content
        db_question.option_a = _clean(question.option_a)
        db_question.option_b = _clean(question.option_b)
        db_question.option_c = _clean(question.option_c)
        db_question.option_d = _clean(question.option_d)
        db_question.answer = answer
        db_question.knowledge_point = _clean(question.knowledge_point)
        db_question.explanation = _clean(question.explanation)
        db_question.wrong_analysis = _clean(question.wrong_analysis)
        db_question.difficulty_level = getattr(question, 'difficulty_level', None)
        db_question.topic_category = getattr(question, 'topic_category', None)
        db_question.updated_time = datetime.now()

        db.commit()
        db.refresh(db_question)

        # Cached question lists are stale now.
        CacheManager.clear_questions_cache()

        logger.info(f"成功更新题目 ID: {question_id}")
        return db_question
    except Exception as e:
        db.rollback()
        logger.error(f"更新题目失败 (ID: {question_id}): {str(e)}")
        raise

def delete_question(db: Session, question_id: int) -> bool:
    """Remove one question by primary key.

    Args:
        db: SQLAlchemy session.
        question_id: Primary key of the question; must be positive.

    Returns:
        ``True`` when the row was deleted, ``False`` when no such row exists.

    Raises:
        ValueError: if ``question_id`` is not greater than zero.
        Exception: any database error; the session is rolled back first.
    """
    try:
        if question_id <= 0:
            raise ValueError("题目ID必须大于0")

        target = (
            db.query(database.Question)
            .filter(database.Question.id == question_id)
            .first()
        )
        if not target:
            logger.warning(f"未找到ID为 {question_id} 的题目")
            return False

        db.delete(target)
        db.commit()

        # Cached question lists are stale now.
        CacheManager.clear_questions_cache()

        logger.info(f"成功删除题目 ID: {question_id}")
        return True
    except Exception as exc:
        db.rollback()
        logger.error(f"删除题目失败 (ID: {question_id}): {str(exc)}")
        raise

# 教师数据看板相关函数
@cached(ttl=300, key_prefix="dashboard")
def get_teacher_dashboard_data(db: Session, class_id: int = None) -> Dict[str, Any]:
    """Assemble the four sections of the teacher dashboard.

    The function is wrapped by ``cached`` with ``ttl=300`` and the
    ``"dashboard"`` key prefix.

    Args:
        db: SQLAlchemy session.
        class_id: Optional class id passed through to every section helper.

    Returns:
        A dict with the keys ``overview``, ``students``, ``errors`` and
        ``suggestions``.  If anything raises, the error is logged and a
        placeholder dict with zeroed/empty sections is returned instead of
        propagating the exception.
    """
    try:
        logger.info(f"获取数据看板数据，班级ID: {class_id}")

        # Dict literals evaluate in order, so the helpers run in the same
        # sequence: overview, students, errors, suggestions.
        sections = {
            "overview": get_dashboard_overview(db, class_id),
            "students": get_dashboard_students_data(db, class_id),
            "errors": get_dashboard_errors_data(db, class_id),
            "suggestions": get_dashboard_suggestions_data(db, class_id)
        }

        student_total = sections["overview"].get('total_students', 0)
        logger.info(f"成功获取数据看板数据，学生数: {student_total}")
        return sections
    except Exception as exc:
        logger.error(f"获取数据看板数据失败: {str(exc)}")
        # Deliberately degrade to placeholder data rather than raising.
        placeholder_overview = {"total_students": 0, "total_submissions": 0, "average_score": 0, "pass_rate": 0, "excellent_rate": 0}
        placeholder_errors = {"common_errors": [], "error_distribution": {}, "knowledge_point_errors": []}
        placeholder_suggestions = {"priority_knowledge_points": [], "individual_guidance": [], "teaching_strategies": []}
        return {
            "overview": placeholder_overview,
            "students": [],
            "errors": placeholder_errors,
            "suggestions": placeholder_suggestions
        }

@cached(ttl=300, key_prefix="overview")
def get_dashboard_overview(db: Session, class_id: int = None) -> Dict[str, Any]:
    """Compute the headline numbers of the teacher dashboard.

    The function is wrapped by ``cached`` with ``ttl=300`` and the
    ``"overview"`` key prefix.

    Args:
        db: SQLAlchemy session.
        class_id: When truthy, only submissions of that class are counted;
            otherwise all submissions are used.

    Returns:
        A dict with ``total_students`` (distinct student names),
        ``total_submissions``, ``average_score``, ``pass_rate`` (score >= 60)
        and ``excellent_rate`` (score >= 85); rates are percentages rounded
        to two decimals.  All values are zero when there are no submissions
        or when an error occurs (the error is logged, not raised).
    """
    empty_overview = {
        "total_students": 0,
        "total_submissions": 0,
        "average_score": 0,
        "pass_rate": 0,
        "excellent_rate": 0
    }
    try:
        query = db.query(database.Submission)
        if class_id:
            query = query.filter(database.Submission.class_id == class_id)

        submissions = query.all()
        if not submissions:
            return empty_overview

        total_submissions = len(submissions)
        total_students = len({s.student_name for s in submissions})

        # A missing score counts as 0 (the per-student view treats it the
        # same way) instead of raising TypeError and zeroing the overview.
        scores = [s.score or 0 for s in submissions]
        average_score = sum(scores) / total_submissions
        pass_count = sum(1 for score in scores if score >= 60)
        excellent_count = sum(1 for score in scores if score >= 85)

        return {
            "total_students": total_students,
            "total_submissions": total_submissions,
            "average_score": round(average_score, 2),
            "pass_rate": round(pass_count / total_submissions * 100, 2),
            "excellent_rate": round(excellent_count / total_submissions * 100, 2)
        }
    except Exception as e:
        logger.error(f"获取概览数据失败: {str(e)}")
        return empty_overview

@cached(ttl=300, key_prefix="students_data")
def get_dashboard_students_data(db: Session, class_id: int = None) -> List[Dict[str, Any]]:
    """List submissions, newest first, as plain dicts for the dashboard.

    The function is wrapped by ``cached`` with ``ttl=300`` and the
    ``"students_data"`` key prefix.

    Args:
        db: SQLAlchemy session.
        class_id: When truthy, only submissions of that class are returned.

    Returns:
        One dict per submission with ``name``, ``score``,
        ``submission_time`` (``YYYY-MM-DD HH:MM:SS``) and ``class_id``;
        missing values are replaced by placeholders.  Rows that fail to
        convert are skipped with a warning; a query failure is logged and
        yields an empty list.
    """
    try:
        newest_first = database.Submission.submission_time.desc()
        query = db.query(database.Submission).order_by(newest_first)
        if class_id:
            query = query.filter(database.Submission.class_id == class_id)

        rows = []
        for record in query.all():
            try:
                if record.submission_time:
                    submitted_at = record.submission_time.strftime("%Y-%m-%d %H:%M:%S")
                else:
                    submitted_at = "未知时间"
                row = {
                    "name": record.student_name or "未知学生",
                    "score": record.score or 0,
                    "submission_time": submitted_at,
                    "class_id": record.class_id or 0
                }
            except Exception as row_error:
                logger.warning(f"处理学生数据时出错: {str(row_error)}")
                continue
            rows.append(row)

        logger.info(f"获取到 {len(rows)} 条学生数据")
        return rows
    except Exception as exc:
        logger.error(f"获取学生数据失败: {str(exc)}")
        return []

def get_dashboard_errors_data(db: Session, class_id: int = None):
    """Return the common-error analysis section of the dashboard.

    Thin best-effort wrapper around ``get_error_analysis``.

    Args:
        db: SQLAlchemy session.
        class_id: Optional class id forwarded to ``get_error_analysis``.

    Returns:
        Whatever ``get_error_analysis`` returns, or an empty placeholder
        structure when it raises (the failure is logged, not propagated).
    """
    try:
        return get_error_analysis(db, class_id)
    except Exception as e:
        # Still best-effort, but no longer silent: record why the section
        # is empty so the failure can be diagnosed.
        logger.error(f"获取错题分析数据失败: {str(e)}")
        return {
            "common_errors": [],
            "error_distribution": {},
            "knowledge_point_errors": []
        }

def get_dashboard_suggestions_data(db: Session, class_id: int = None):
    """Return the teaching-suggestions section of the dashboard.

    Thin best-effort wrapper around ``get_teaching_suggestions``.

    Args:
        db: SQLAlchemy session.
        class_id: Optional class id forwarded to ``get_teaching_suggestions``.

    Returns:
        Whatever ``get_teaching_suggestions`` returns, or an empty
        placeholder structure when it raises (the failure is logged, not
        propagated).
    """
    try:
        return get_teaching_suggestions(db, class_id)
    except Exception as e:
        # Still best-effort, but no longer silent: record why the section
        # is empty so the failure can be diagnosed.
        logger.error(f"获取教学建议数据失败: {str(e)}")
        return {
            "priority_knowledge_points": [],
            "individual_guidance": [],
            "teaching_strategies": []
        }

def export_dashboard_data(db: Session, export_request: schemas.ExportRequest):
    """Return dashboard data in the shape selected by the export request.

    Args:
        db: SQLAlchemy session.
        export_request: Must expose ``class_id``; ``export_type`` is optional
            and defaults to ``"analysis"``.

    Returns:
        * ``"analysis"``: the full dashboard dict.
        * ``"student_data"``: the list of per-submission dicts.
        * anything else: the dashboard dict plus a ``student_performance``
          key holding the per-submission list.
    """
    export_type = getattr(export_request, 'export_type', 'analysis')
    class_id = export_request.class_id

    if export_type == "analysis":
        return get_teacher_dashboard_data(db, class_id)
    if export_type == "student_data":
        return get_dashboard_students_data(db, class_id)

    # Fallback: combined export.
    dashboard_data = get_teacher_dashboard_data(db, class_id)
    students_data = get_dashboard_students_data(db, class_id)

    return {
        **dashboard_data,
        # get_dashboard_students_data returns a list, so it is used as-is;
        # calling .get() on it raised AttributeError.
        'student_performance': students_data
    }

def batch_create_questions(db: Session, questions: List[schemas.QuestionCreate]):
    """Replace all questions with ``questions``, stored in the active bank.

    Existing questions are deleted first.  If no bank is active, the default
    bank ('默认题库') is created or activated and used as the target.

    Args:
        db: SQLAlchemy session.
        questions: Payloads to insert; field values are stored as given.

    Returns:
        The list of newly created ``database.Question`` objects.

    Raises:
        Exception: any failure; the session is rolled back, the error is
            logged, and the exception is re-raised.  Note that the initial
            clearing step commits on its own and is not undone.
    """
    try:
        # Start from an empty question table.
        clear_all_questions(db)

        # Resolve the target bank.
        active_bank = get_active_question_bank(db)
        if not active_bank:
            default_bank = db.query(database.QuestionBank).filter(database.QuestionBank.name == '默认题库').first()
            if not default_bank:
                now = datetime.now()
                default_bank = database.QuestionBank(
                    name='默认题库',
                    description='系统自动创建的默认题库',
                    created_by='system',
                    created_time=now,
                    updated_time=now,
                    is_active=True
                )
                db.add(default_bank)
                db.commit()
                db.refresh(default_bank)
            else:
                activate_question_bank(db, default_bank.id)
            active_bank = default_bank

        db_questions = [
            database.Question(
                question_content=question.question_content,
                option_a=question.option_a,
                option_b=question.option_b,
                option_c=question.option_c,
                option_d=question.option_d,
                answer=question.answer,
                knowledge_point=question.knowledge_point,
                explanation=question.explanation,
                wrong_analysis=question.wrong_analysis,
                question_bank_id=active_bank.id
            )
            for question in questions
        ]

        db.add_all(db_questions)
        db.commit()

        # Keep the bank's stored question count in step with the new rows.
        update_question_bank_count(db, active_bank.id)

        # The cache was cleared when the old rows were deleted, but a read in
        # between could have cached the empty table; clear again after insert.
        CacheManager.clear_questions_cache()

        return db_questions
    except Exception as e:
        db.rollback()
        logger.error(f"批量创建题目失败: {str(e)}")
        raise


def ai_generate_questions_to_new_bank(db: Session, questions: List[schemas.QuestionCreate], bank_name: str = None, knowledge_points: str = None, generation_mode: str = "single"):
    """Store AI-generated questions in a brand-new, active question bank.

    Existing banks and their questions are kept; only the currently active
    bank (if any) is deactivated in favour of the new one.

    Args:
        db: SQLAlchemy session.
        questions: Payloads to insert; field values are stored as given.
        bank_name: Name of the new bank; defaults to a timestamped name.
        knowledge_points: Optional text recorded in the bank description.
        generation_mode: ``"batch"`` selects the batch wording of the
            description; any other value uses the single-topic wording.

    Returns:
        Tuple ``(db_questions, new_bank)``.

    Raises:
        Exception: any failure; the session is rolled back, the error is
            logged, and the exception is re-raised.
    """
    # One timestamp for the name, the description and the audit columns.
    now = datetime.now()
    if not bank_name:
        bank_name = f"AI生成题库_{now.strftime('%Y%m%d_%H%M%S')}"

    # Description: creation time, optional knowledge points, question count.
    description_parts = [
        f'AI生成的题库，创建时间：{now.strftime("%Y-%m-%d %H:%M:%S")}'
    ]
    if knowledge_points:
        if generation_mode == "batch":
            description_parts.append(f'批量生成模式，涉及知识点：{knowledge_points}')
        else:
            description_parts.append(f'基于知识点：{knowledge_points}')
    description_parts.append(f'共包含 {len(questions)} 道题目')

    try:
        # Deactivate the current bank in the SAME transaction that creates
        # the new one, so a failure cannot leave the system with no active
        # bank.
        current_active = get_active_question_bank(db)
        if current_active:
            current_active.is_active = False

        new_bank = database.QuestionBank(
            name=bank_name,
            description='；'.join(description_parts),
            created_by='AI',
            created_time=now,
            updated_time=now,
            is_active=True
        )
        db.add(new_bank)
        # Flush (not commit) to obtain new_bank.id for the foreign keys.
        db.flush()

        db_questions = [
            database.Question(
                question_content=question.question_content,
                option_a=question.option_a,
                option_b=question.option_b,
                option_c=question.option_c,
                option_d=question.option_d,
                answer=question.answer,
                knowledge_point=question.knowledge_point,
                explanation=question.explanation,
                wrong_analysis=question.wrong_analysis,
                question_bank_id=new_bank.id
            )
            for question in questions
        ]
        db.add_all(db_questions)
        db.commit()
        db.refresh(new_bank)

        # Keep the bank's stored question count in step with the new rows.
        update_question_bank_count(db, new_bank.id)

        # Cached question lists are stale now.
        CacheManager.clear_questions_cache()

        return db_questions, new_bank
    except Exception as e:
        db.rollback()
        logger.error(f"AI生成题目到新题库失败: {str(e)}")
        raise

# 提交记录相关操作
def create_submission(db: Session, submission: schemas.SubmissionCreate, score: int, client_ip: str = None, total_questions: int = 0, correct_answers: int = 0) -> database.Submission:
    """Persist one quiz submission.

    Args:
        db: SQLAlchemy session.
        submission: Payload providing ``student_name`` and ``class_id``.
        score: Final score; must lie in 0..100.
        client_ip: Optional address of the submitting client.
        total_questions: Number of questions in the attempt.
        correct_answers: Number of correctly answered questions.

    Returns:
        The stored ``database.Submission`` (refreshed after commit).

    Raises:
        ValueError: blank student name, score outside 0..100, or a
            non-positive class id.
        Exception: any database error; the session is rolled back first.
    """
    try:
        # Validate the input.
        raw_name = submission.student_name
        if not raw_name or not raw_name.strip():
            raise ValueError("学生姓名不能为空")
        if score < 0 or score > 100:
            raise ValueError("分数必须在0-100之间")
        if submission.class_id <= 0:
            raise ValueError("班级ID必须大于0")

        submission_day = date.today().strftime("%Y-%m-%d")

        record = database.Submission(
            student_name=raw_name.strip(),
            score=score,
            submission_date=submission_day,
            class_id=submission.class_id,
            submission_time=datetime.now(),
            client_ip=client_ip,
            total_questions=total_questions,
            correct_answers=correct_answers
        )
        db.add(record)
        db.commit()
        db.refresh(record)
        return record
    except Exception as exc:
        db.rollback()
        logger.error(f"创建提交记录失败: {str(exc)}")
        raise

@cached(ttl=300, key_prefix="submissions")
def get_submissions_by_class_and_date(db: Session, class_id: int, target_date: str = None) -> List[database.Submission]:
    """Return a class's submissions for one day, newest first.

    The function is wrapped by ``cached`` with ``ttl=300`` and the
    ``"submissions"`` key prefix.

    Args:
        db: SQLAlchemy session.
        class_id: Class to query; must be positive.
        target_date: Day in ``YYYY-MM-DD`` form; defaults to today.

    Returns:
        Matching ``database.Submission`` rows ordered by submission time,
        descending.

    Raises:
        ValueError: non-positive class id or malformed date string.
        Exception: any other failure is logged and re-raised unchanged.
    """
    try:
        if class_id <= 0:
            raise ValueError("班级ID必须大于0")

        day = target_date if target_date is not None else date.today().strftime("%Y-%m-%d")

        # Reject anything that does not parse as YYYY-MM-DD.
        try:
            datetime.strptime(day, "%Y-%m-%d")
        except ValueError:
            raise ValueError("日期格式必须为YYYY-MM-DD")

        criteria = and_(
            database.Submission.class_id == class_id,
            database.Submission.submission_date == day
        )
        newest_first = database.Submission.submission_time.desc()
        rows = db.query(database.Submission).filter(criteria).order_by(newest_first).all()

        logger.info(f"获取到 {len(rows)} 条提交记录 (班级: {class_id}, 日期: {day})")
        return rows
    except Exception as exc:
        logger.error(f"获取提交记录失败: {str(exc)}")
        raise

# 提交详情相关操作
def create_submission_detail(db: Session, submission_id: int, question_id: int, selected_answer: str, is_correct: int = 0) -> database.SubmissionDetail:
    """Persist the answer given to one question of a submission.

    Args:
        db: SQLAlchemy session.
        submission_id: Id of the owning submission.
        question_id: Id of the answered question.
        selected_answer: The option the student chose.
        is_correct: Correctness flag stored with the row (default 0).

    Returns:
        The stored ``database.SubmissionDetail`` (refreshed after commit).

    Raises:
        Exception: any database error; the session is rolled back first.
    """
    try:
        fields = dict(
            submission_id=submission_id,
            question_id=question_id,
            selected_answer=selected_answer,
            is_correct=is_correct,
            answer_time=datetime.now()
        )
        detail_row = database.SubmissionDetail(**fields)
        db.add(detail_row)
        db.commit()
        db.refresh(detail_row)
        return detail_row
    except Exception as exc:
        db.rollback()
        logger.error(f"创建提交详情失败: {str(exc)}")
        raise

def batch_create_submission_details(db: Session, submission_id: int, answers: List[schemas.StudentAnswer]):
    """Persist all answers of one submission in a single commit.

    Each answer is marked correct when its letter equals the first letter of
    the stored correct answer, ignoring case.  Answers with a non-positive
    question id or a choice outside A-D are skipped with a warning.

    Args:
        db: SQLAlchemy session.
        submission_id: Id of the owning submission; must be positive.
        answers: Student answers exposing ``questionId`` and
            ``selectedAnswer``.

    Returns:
        The list of created ``database.SubmissionDetail`` objects.

    Raises:
        ValueError: non-positive submission id, empty answer list, or no
            answer survived validation.
        Exception: any database error; the session is rolled back first.
    """
    valid_choices = ('A', 'B', 'C', 'D')
    try:
        if submission_id <= 0:
            raise ValueError("提交ID必须大于0")
        if not answers:
            raise ValueError("答案列表不能为空")

        # Load the correct answers for all referenced questions in one query.
        question_ids = [answer.questionId for answer in answers if answer.questionId > 0]
        questions = db.query(database.Question).filter(database.Question.id.in_(question_ids)).all()
        question_map = {q.id: q.answer for q in questions}

        db_details = []
        for answer in answers:
            if answer.questionId <= 0:
                logger.warning(f"跳过无效的题目ID: {answer.questionId}")
                continue

            # Normalise before validating so a lower-case choice ("a") is
            # recorded, matching the case-insensitive comparison below.
            selected = (answer.selectedAnswer or "").strip().upper()
            if selected not in valid_choices:
                logger.warning(f"跳过无效的答案: {answer.selectedAnswer}")
                continue

            # Compare against the first letter of the stored correct answer.
            correct_answer = (question_map.get(answer.questionId) or "").strip()
            is_correct = 1 if correct_answer and selected == correct_answer[0].upper() else 0

            db_details.append(database.SubmissionDetail(
                submission_id=submission_id,
                question_id=answer.questionId,
                selected_answer=selected,
                is_correct=is_correct,
                answer_time=datetime.now()
            ))

        if not db_details:
            raise ValueError("没有有效的答案数据")

        db.add_all(db_details)
        db.commit()

        # Cached submission lists are stale now.
        CacheManager.clear_submissions_cache()

        logger.info(f"成功创建 {len(db_details)} 条提交详情")
        return db_details
    except Exception as e:
        db.rollback()
        logger.error(f"批量创建提交详情失败: {str(e)}")
        raise

def get_submission_details(db: Session, submission_id: int) -> List[database.SubmissionDetail]:
    """Return every detail row that belongs to the given submission.

    Args:
        db: SQLAlchemy session.
        submission_id: Id of the submission whose details are wanted.

    Returns:
        All matching ``database.SubmissionDetail`` rows (possibly empty).
    """
    belongs_to_submission = database.SubmissionDetail.submission_id == submission_id
    detail_query = db.query(database.SubmissionDetail).filter(belongs_to_submission)
    return detail_query.all()

def check_name_lock_by_ip(db: Session, client_ip: str, student_name: str, time_window_minutes: int = 30) -> tuple:
    """Check whether an IP is currently locked to a different student name.

    An IP is bound to the name used in its earliest submission inside the
    time window; any other name is reported as locked out.

    Args:
        db: SQLAlchemy session.
        client_ip: Client address; a falsy value disables the check.
        student_name: Name the client wants to submit under.
        time_window_minutes: Length of the look-back window in minutes.

    Returns:
        tuple: ``(is_locked, locked_name, remaining_minutes)``
        - is_locked: ``True`` when ``student_name`` differs from the name
          the IP is bound to.
        - locked_name: the bound name, or ``None`` when the IP has no
          submission inside the window.
        - remaining_minutes: whole minutes left until the binding expires.
        On any error the failure is logged and ``(False, None, 0)`` is
        returned, i.e. the check fails open.
    """
    # `timedelta` is not among the module-level datetime imports; without
    # this import the lookup below raised NameError, which the handler
    # swallowed, so the lock could never trigger.
    from datetime import timedelta

    try:
        if not client_ip:
            return False, None, 0

        # One clock read keeps the threshold and the remaining time consistent.
        now = datetime.now()
        time_threshold = now - timedelta(minutes=time_window_minutes)

        # Earliest submission from this IP inside the window defines the lock.
        first_submission = db.query(database.Submission).filter(
            database.Submission.client_ip == client_ip,
            database.Submission.submission_time >= time_threshold
        ).order_by(database.Submission.submission_time.asc()).first()

        if not first_submission:
            # No submission in the window: nothing is locked.
            return False, None, 0

        locked_name = first_submission.student_name
        elapsed = now - first_submission.submission_time
        remaining_minutes = max(0, time_window_minutes - int(elapsed.total_seconds() / 60))

        # A missing name is treated as an empty one, which cannot match the
        # bound name, instead of raising on None.strip().
        is_locked = (student_name or "").strip() != locked_name
        return is_locked, locked_name, remaining_minutes

    except Exception as e:
        logger.error(f"检查姓名锁定情况失败: {str(e)}")
        return False, None, 0

def get_recent_submissions_by_ip(db: Session, client_ip: str, time_window_minutes: int = 30) -> List[database.Submission]:
    """Return the submissions made from one IP inside the time window.

    Args:
        db: SQLAlchemy session.
        client_ip: Client address; a falsy value yields an empty list.
        time_window_minutes: Length of the look-back window in minutes.

    Returns:
        List[database.Submission]: matching rows, newest first.  On any
        error the failure is logged and an empty list is returned.
    """
    # `timedelta` is not among the module-level datetime imports; without
    # this import the function raised NameError internally and always
    # returned an empty list.
    from datetime import timedelta

    try:
        if not client_ip:
            return []

        time_threshold = datetime.now() - timedelta(minutes=time_window_minutes)

        return db.query(database.Submission).filter(
            database.Submission.client_ip == client_ip,
            database.Submission.submission_time >= time_threshold
        ).order_by(database.Submission.submission_time.desc()).all()

    except Exception as e:
        logger.error(f"获取最近提交记录失败: {str(e)}")
        return []

def delete_submission(db: Session, submission_id: int) -> bool:
    """Delete a submission together with its detail rows.

    Args:
        db: SQLAlchemy session.
        submission_id: Primary key of the submission to remove.

    Returns:
        ``True`` when the submission existed and was deleted, ``False`` when
        no such submission exists (nothing is modified in that case).

    Raises:
        Exception: any database error; the session is rolled back, the error
            is logged, and the exception is re-raised.
    """
    try:
        # Look the submission up first so that a miss leaves no un-committed
        # detail deletion pending in the session.
        db_submission = db.query(database.Submission).filter(
            database.Submission.id == submission_id
        ).first()
        if not db_submission:
            return False

        # Remove the detail rows, then the submission, in one commit.
        db.query(database.SubmissionDetail).filter(
            database.SubmissionDetail.submission_id == submission_id
        ).delete()
        db.delete(db_submission)
        db.commit()

        # Cached submission lists are stale now.
        CacheManager.clear_submissions_cache()
        return True
    except Exception as e:
        db.rollback()
        logger.error(f"删除提交记录失败 (ID: {submission_id}): {str(e)}")
        raise

# 计分相关函数
def calculate_score(db: Session, answers: List[schemas.StudentAnswer]) -> tuple:
    """Score a set of answers and build the per-question result list.

    Every submitted answer is worth ``100 / len(answers)`` points.  An answer
    is correct when its first non-blank character equals the first non-blank
    character of the stored correct answer, ignoring case.

    Args:
        db: SQLAlchemy session.
        answers: Student answers exposing ``questionId`` and
            ``selectedAnswer``.

    Returns:
        tuple: ``(final_score, full_results)`` where ``final_score`` is the
        rounded integer score and ``full_results`` holds one
        ``schemas.QuestionResult`` per answer whose question exists, in
        submission order.  ``(0, [])`` when there are no answers or none of
        the referenced questions exist.
    """
    if not answers:
        return 0, []

    # Load only the questions the student actually answered.
    answered_question_ids = [answer.questionId for answer in answers]
    questions = db.query(database.Question).filter(
        database.Question.id.in_(answered_question_ids)
    ).order_by(database.Question.id).all()

    if not questions:
        return 0, []

    # The denominator is the number of submitted answers, so answers that
    # reference unknown questions still count as wrong.
    points_per_question = 100 / len(answers)
    question_map = {q.id: q for q in questions}

    correct_count = 0
    full_results = []
    for answer in answers:
        question = question_map.get(answer.questionId)
        if not question:
            continue  # unknown question id

        student_answer = answer.selectedAnswer
        given = (student_answer or "").strip()
        expected = (question.answer or "").strip()
        # Both sides must be non-blank; two blank strings are NOT a match.
        is_correct = bool(given) and bool(expected) and given[0].upper() == expected[0].upper()

        if is_correct:
            correct_count += 1

        full_results.append(schemas.QuestionResult(
            questionContent=question.question_content,
            yourAnswer=student_answer,
            correctAnswer=question.answer,
            knowledgePoint=question.knowledge_point,
            explanation=question.explanation,
            wrongAnalysis=question.wrong_analysis,
            isCorrect=is_correct
        ))

    # Final score, rounded to an integer.
    final_score = round(correct_count * points_per_question)

    return final_score, full_results

# 学情分析相关函数
def get_class_analysis(db: Session, class_id: int, date: str = None) -> schemas.ClassAnalysis:
    """Build the learning analysis for one class on one day.

    Loads the class's submissions for ``date`` and derives: average score,
    score-band distribution, per-knowledge-point accuracy, the five highest
    scoring submissions, submissions scoring below 60, and the knowledge
    points whose accuracy is below 60%.

    Args:
        db: SQLAlchemy session.
        class_id: Class whose submissions are analysed.
        date: Day to analyse, compared against ``Submission.submission_date``;
            defaults to today formatted as ``YYYY-MM-DD``.  Note that this
            parameter shadows the module-level ``date`` import inside the
            function.

    Returns:
        A ``schemas.ClassAnalysis``; all aggregates are zero/empty when no
        submission matches.
    """
    from datetime import datetime
    from collections import defaultdict, Counter

    # Default to today when no date is given.
    if not date:
        date = datetime.now().strftime("%Y-%m-%d")

    # All submissions of the class on that day.
    submissions = db.query(database.Submission).filter(
        database.Submission.class_id == class_id,
        database.Submission.submission_date == date
    ).all()

    if not submissions:
        return schemas.ClassAnalysis(
            class_id=class_id,
            total_students=0,
            average_score=0.0,
            score_distribution={},
            knowledge_point_analysis=[],
            top_performers=[],
            need_help_students=[],
            common_mistakes=[]
        )

    # Index all questions by id for the detail lookups below.
    questions = get_questions(db)
    question_map = {q.id: q for q in questions}

    # Basic statistics.
    # NOTE(review): total_students is the number of submissions, not the
    # number of distinct student names - confirm this is intended.
    total_students = len(submissions)
    total_score = sum(s.score for s in submissions)
    average_score = round(total_score / total_students, 2)

    # Score-band distribution.
    score_ranges = {"90-100": 0, "80-89": 0, "70-79": 0, "60-69": 0, "0-59": 0}
    for submission in submissions:
        score = submission.score
        if score >= 90:
            score_ranges["90-100"] += 1
        elif score >= 80:
            score_ranges["80-89"] += 1
        elif score >= 70:
            score_ranges["70-79"] += 1
        elif score >= 60:
            score_ranges["60-69"] += 1
        else:
            score_ranges["0-59"] += 1

    # Per-knowledge-point counters and per-submission performance records.
    knowledge_point_stats = defaultdict(lambda: {"total": 0, "correct": 0, "wrong_answers": []})
    student_performance = []

    for submission in submissions:
        # Answer details of this submission (one query per submission).
        details = db.query(database.SubmissionDetail).filter(
            database.SubmissionDetail.submission_id == submission.id
        ).all()

        student_correct = 0
        student_total = len(details)
        weak_points = []

        for detail in details:
            question = question_map.get(detail.question_id)
            # Details whose question is not in question_map are ignored.
            if question:
                knowledge_point = question.knowledge_point
                knowledge_point_stats[knowledge_point]["total"] += 1

                # Correctness is recomputed here by comparing the first
                # letters case-insensitively; the stored is_correct flag is
                # not read.
                is_correct = False
                if detail.selected_answer and question.answer:
                    student_first_char = detail.selected_answer.strip()[0].upper() if detail.selected_answer.strip() else ""
                    correct_first_char = question.answer.strip()[0].upper() if question.answer.strip() else ""
                    is_correct = student_first_char == correct_first_char

                if is_correct:
                    knowledge_point_stats[knowledge_point]["correct"] += 1
                    student_correct += 1
                else:
                    knowledge_point_stats[knowledge_point]["wrong_answers"].append(detail.selected_answer)
                    weak_points.append(knowledge_point)

        # Performance record for this submission.  accuracy_rate is relative
        # to all detail rows, including those whose question was not found.
        # NOTE(review): submission_time is used without a None check here.
        accuracy_rate = round((student_correct / student_total * 100), 2) if student_total > 0 else 0
        student_performance.append(schemas.StudentPerformance(
            student_name=submission.student_name,
            score=submission.score,
            accuracy_rate=accuracy_rate,
            weak_knowledge_points=list(set(weak_points)),
            submission_time=submission.submission_time.strftime("%H:%M:%S")
        ))

    # Turn the counters into KnowledgePointAnalysis objects.
    knowledge_point_analysis = []
    for kp, stats in knowledge_point_stats.items():
        accuracy = round((stats["correct"] / stats["total"] * 100), 2) if stats["total"] > 0 else 0
        # The three most frequent wrong choices for this knowledge point.
        common_wrong = Counter(stats["wrong_answers"]).most_common(3)

        knowledge_point_analysis.append(schemas.KnowledgePointAnalysis(
            knowledge_point=kp,
            total_questions=stats["total"],
            correct_count=stats["correct"],
            incorrect_count=stats["total"] - stats["correct"],
            accuracy_rate=accuracy,
            common_wrong_answers=[answer for answer, count in common_wrong]
        ))

    # Rank by score, highest first.
    student_performance.sort(key=lambda x: x.score, reverse=True)
    top_performers = student_performance[:5]  # top five
    need_help_students = [s for s in student_performance if s.score < 60]  # below the pass mark

    # "Common mistakes" are the knowledge points with accuracy below 60%.
    common_mistakes = [kp.knowledge_point for kp in knowledge_point_analysis if kp.accuracy_rate < 60]

    return schemas.ClassAnalysis(
        class_id=class_id,
        total_students=total_students,
        average_score=average_score,
        score_distribution=score_ranges,
        knowledge_point_analysis=knowledge_point_analysis,
        top_performers=top_performers,
        need_help_students=need_help_students,
        common_mistakes=common_mistakes
    )


def get_student_learning_trajectory(db: Session, class_id: int, student_name: str):
    """Build the learning trajectory of one student within a class.

    Loads every submission of ``student_name`` in ``class_id``, ordered by
    ``submission_date`` descending, and summarises each of them.

    Args:
        db: Database session used for the queries.
        class_id: Class the submissions must belong to.
        student_name: Student name, matched exactly.

    Returns:
        A dict with ``student_name``, ``class_id``, ``total_attempts`` and
        ``trajectory``. Each trajectory entry holds the submission date and
        time, the score, the number of answer details, the number of correct
        answers, the accuracy (percent) per knowledge point and the
        submission id.
    """
    submissions = db.query(database.Submission).filter(
        database.Submission.class_id == class_id,
        database.Submission.student_name == student_name
    ).order_by(database.Submission.submission_date.desc()).all()

    # The question lookup is identical for every submission, so it is built
    # once up front (and not at all when there is nothing to analyse).
    question_map = {q.id: q for q in get_questions(db)} if submissions else {}

    trajectory = []
    for submission in submissions:
        details = db.query(database.SubmissionDetail).filter(
            database.SubmissionDetail.submission_id == submission.id
        ).all()

        # Tally the answers per knowledge point. Details whose question is
        # missing from the lookup are neither counted nor graded.
        knowledge_performance = {}
        correct_count = 0
        for detail in details:
            question = question_map.get(detail.question_id)
            if not question:
                continue

            stats = knowledge_performance.setdefault(
                question.knowledge_point, {'correct': 0, 'total': 0}
            )
            stats['total'] += 1

            # Grade each answer exactly once with the shared helper.
            if is_answer_correct(detail, question_map):
                stats['correct'] += 1
                correct_count += 1

        # Accuracy per knowledge point, as a percentage with two decimals.
        kp_mastery = {
            kp: round((stats['correct'] / stats['total'] * 100), 2) if stats['total'] > 0 else 0
            for kp, stats in knowledge_performance.items()
        }

        trajectory.append({
            'date': submission.submission_date,
            'time': submission.submission_time.strftime("%H:%M:%S"),
            'score': submission.score,
            'total_questions': len(details),
            'correct_count': correct_count,
            'knowledge_point_mastery': kp_mastery,
            'submission_id': submission.id
        })

    return {
        'student_name': student_name,
        'class_id': class_id,
        'total_attempts': len(trajectory),
        'trajectory': trajectory
    }


def get_error_analysis(db: Session, class_id: int, date: str = None):
    """Analyse the wrong answers of a class for one day.

    Args:
        db: Database session used for the queries.
        class_id: Class whose submissions are analysed.
        date: Value compared against ``Submission.submission_date``. When
            omitted, today's local date formatted as ``YYYY-MM-DD`` is used.

    Returns:
        A dict with ``class_id``, ``date``, ``total_students`` (the number of
        matching submissions), ``question_error_ranking`` (at most the 20
        questions with the highest error rate), ``knowledge_point_errors``
        (sorted by error rate, highest first) and ``summary``. When no
        submission matches, the lists are empty, the summary values are
        ``None``/0 and the legacy ``error_statistics`` key is still present.
    """
    from datetime import datetime
    from collections import defaultdict, Counter

    if not date:
        date = datetime.now().strftime("%Y-%m-%d")

    submissions = db.query(database.Submission).filter(
        database.Submission.class_id == class_id,
        database.Submission.submission_date == date
    ).all()

    if not submissions:
        # Same keys as the populated result so callers can read both shapes
        # uniformly; 'error_statistics' is kept for existing consumers.
        return {
            'class_id': class_id,
            'date': date,
            'total_students': 0,
            'error_statistics': [],
            'question_error_ranking': [],
            'knowledge_point_errors': [],
            'summary': {
                'most_difficult_question': None,
                'most_problematic_knowledge_point': None,
                'average_error_rate': 0
            }
        }

    question_map = {q.id: q for q in get_questions(db)}

    # Per-question and per-knowledge-point tallies.
    question_errors = defaultdict(lambda: {'total_attempts': 0, 'error_count': 0, 'wrong_answers': []})
    knowledge_point_errors = defaultdict(lambda: {'total_attempts': 0, 'error_count': 0, 'error_questions': set()})

    for submission in submissions:
        details = db.query(database.SubmissionDetail).filter(
            database.SubmissionDetail.submission_id == submission.id
        ).all()

        for detail in details:
            question = question_map.get(detail.question_id)
            if not question:
                # Answers to questions missing from the lookup are ignored.
                continue

            question_stats = question_errors[detail.question_id]
            kp_stats = knowledge_point_errors[question.knowledge_point]
            question_stats['total_attempts'] += 1
            kp_stats['total_attempts'] += 1

            if not is_answer_correct(detail, question_map):
                question_stats['error_count'] += 1
                question_stats['wrong_answers'].append(detail.selected_answer)
                kp_stats['error_count'] += 1
                kp_stats['error_questions'].add(detail.question_id)

    # Ranking of questions by error rate.
    question_error_ranking = []
    for question_id, stats in question_errors.items():
        question = question_map.get(question_id)
        if question and stats['total_attempts'] > 0:
            error_rate = round((stats['error_count'] / stats['total_attempts'] * 100), 2)
            common_wrong = Counter(stats['wrong_answers']).most_common(3)

            # The question text lives in `question_content` (the attribute
            # used everywhere else in this module); long texts are truncated
            # to a 100-character preview.
            content = question.question_content or ''
            preview = content[:100] + '...' if len(content) > 100 else content

            question_error_ranking.append({
                'question_id': question_id,
                'question_content': preview,
                'knowledge_point': question.knowledge_point,
                'error_rate': error_rate,
                'error_count': stats['error_count'],
                'total_attempts': stats['total_attempts'],
                'common_wrong_answers': [{'answer': answer, 'count': count} for answer, count in common_wrong],
                'correct_answer': question.answer
            })

    question_error_ranking.sort(key=lambda x: x['error_rate'], reverse=True)

    # Error statistics per knowledge point.
    kp_error_stats = []
    for kp, stats in knowledge_point_errors.items():
        if stats['total_attempts'] > 0:
            error_rate = round((stats['error_count'] / stats['total_attempts'] * 100), 2)
            kp_error_stats.append({
                'knowledge_point': kp,
                'error_rate': error_rate,
                'error_count': stats['error_count'],
                'total_attempts': stats['total_attempts'],
                'error_question_count': len(stats['error_questions'])
            })

    kp_error_stats.sort(key=lambda x: x['error_rate'], reverse=True)

    return {
        'class_id': class_id,
        'date': date,
        'total_students': len(submissions),
        'question_error_ranking': question_error_ranking[:20],  # top 20 by error rate
        'knowledge_point_errors': kp_error_stats,
        'summary': {
            'most_difficult_question': question_error_ranking[0] if question_error_ranking else None,
            'most_problematic_knowledge_point': kp_error_stats[0] if kp_error_stats else None,
            'average_error_rate': round(sum(q['error_rate'] for q in question_error_ranking) / len(question_error_ranking), 2) if question_error_ranking else 0
        }
    }


def get_teaching_suggestions(db: Session, class_id: int, date: str = None):
    """Assemble teaching suggestions for a class.

    Runs the class analysis and the error analysis for ``class_id``/``date``
    and feeds them to the ``generate_*`` helpers.

    Returns:
        A dict with ``class_id``, ``date`` (the given date, or today's local
        date as ``YYYY-MM-DD``), ``overall_assessment``,
        ``priority_knowledge_points``, ``individual_guidance``,
        ``teaching_strategies`` and ``next_steps``.
    """
    from datetime import datetime

    class_analysis = get_class_analysis(db, class_id, date)
    error_analysis = get_error_analysis(db, class_id, date)

    # Filled key by key, in the order the report sections are listed.
    suggestions = {'class_id': class_id}
    suggestions['date'] = date if date else datetime.now().strftime("%Y-%m-%d")
    suggestions['overall_assessment'] = generate_overall_assessment(class_analysis)
    suggestions['priority_knowledge_points'] = generate_priority_suggestions(
        class_analysis.knowledge_point_analysis
    )
    suggestions['individual_guidance'] = generate_individual_guidance(
        class_analysis.need_help_students
    )
    suggestions['teaching_strategies'] = generate_teaching_strategies(class_analysis, error_analysis)
    suggestions['next_steps'] = generate_next_steps(class_analysis, error_analysis)
    return suggestions


def generate_overall_assessment(class_analysis):
    """Summarise the overall performance of a class.

    Args:
        class_analysis: Object exposing ``average_score``, ``total_students``
            and ``score_distribution`` (a mapping read via ``.get`` for the
            ``'90-100'`` and ``'0-59'`` buckets).

    Returns:
        A dict with ``level``, ``description``, ``average_score``,
        ``excellent_rate``, ``pass_rate`` and ``total_students``. Both rates
        are percentages rounded to two decimals, or 0 for an empty class.
    """
    mean_score = class_analysis.average_score
    student_count = class_analysis.total_students
    top_bucket = class_analysis.score_distribution.get('90-100', 0)
    passed = student_count - class_analysis.score_distribution.get('0-59', 0)

    def _percentage(part):
        # Guard against dividing by zero for a class without students.
        if student_count > 0:
            return round((part / student_count * 100), 2)
        return 0

    # (minimum average score, level, description), checked from best to worst.
    grade_bands = (
        (85, "优秀", "班级整体表现优秀，学生掌握情况良好"),
        (75, "良好", "班级整体表现良好，大部分学生掌握基本知识"),
        (60, "一般", "班级整体表现一般，需要加强基础知识教学"),
    )
    level, description = "较差", "班级整体表现较差，需要重点关注和辅导"
    for floor, band_level, band_description in grade_bands:
        if mean_score >= floor:
            level, description = band_level, band_description
            break

    return {
        'level': level,
        'description': description,
        'average_score': mean_score,
        'excellent_rate': _percentage(top_bucket),
        'pass_rate': _percentage(passed),
        'total_students': student_count
    }


def generate_priority_suggestions(knowledge_points):
    """Pick the weakest knowledge points and attach recommendations.

    Knowledge points with an ``accuracy_rate`` below 70 are ordered from the
    lowest accuracy upwards and the first five are reported.

    Args:
        knowledge_points: Iterable of objects exposing ``knowledge_point``
            and ``accuracy_rate``.

    Returns:
        A list of dicts with ``knowledge_point``, ``accuracy_rate``,
        ``urgency`` and ``recommendations``.
    """
    weakest_first = sorted(
        (kp for kp in knowledge_points if kp.accuracy_rate < 70),
        key=lambda item: item.accuracy_rate,
    )

    suggestions = []
    for kp in weakest_first[:5]:  # report at most five knowledge points
        topic = kp.knowledge_point
        rate = kp.accuracy_rate

        if rate < 40:
            urgency = "紧急"
        elif rate < 60:
            urgency = "重要"
        else:
            urgency = "关注"

        recommendations = [
            f"安排专项练习，重点讲解{topic}的核心概念",
            "分析常见错误，针对性纠正学生理解偏差",
            f"设计相关案例，帮助学生深入理解{topic}",
        ]
        if rate < 40:
            # The most urgent points get an extra tutoring recommendation.
            recommendations.append("建议课后个别辅导成绩较差的学生")

        suggestions.append({
            'knowledge_point': topic,
            'accuracy_rate': rate,
            'urgency': urgency,
            'recommendations': recommendations,
        })

    return suggestions


def generate_individual_guidance(need_help_students):
    """Create guidance entries for students who need help.

    Only the first ten students of ``need_help_students`` are covered.

    Args:
        need_help_students: Sequence of objects exposing ``student_name``,
            ``score``, ``accuracy_rate`` and ``weak_knowledge_points``.

    Returns:
        A list of dicts with ``student_name``, ``score``, ``accuracy_rate``,
        ``weak_points`` and a fixed list of ``suggestions``.
    """
    return [
        {
            'student_name': student.student_name,
            'score': student.score,
            'accuracy_rate': student.accuracy_rate,
            'weak_points': student.weak_knowledge_points,
            # A new list per entry, so entries never share suggestions.
            'suggestions': [
                "课后单独辅导，重点关注薄弱知识点",
                "提供额外练习材料，巩固基础概念",
                "建议与家长沟通，共同关注学生学习进度",
                "安排学习伙伴，进行同伴互助学习",
            ],
        }
        for student in need_help_students[:10]
    ]


def generate_teaching_strategies(class_analysis, error_analysis):
    """Return the fixed catalogue of teaching strategies.

    Both arguments are accepted for a uniform signature with the other
    ``generate_*`` helpers but are not read.

    Returns:
        A list of three dicts, each with ``category``, ``description`` and
        ``actions``.
    """
    catalogue = (
        (
            '分层教学',
            '根据学生表现分组教学',
            (
                '优秀学生可进行拓展学习和深度思考',
                '中等学生加强练习和理解',
                '后进学生重点补强基础知识',
            ),
        ),
        (
            '错题回顾',
            '定期组织错题讲解',
            (
                '每周安排错题回顾课',
                '重点分析高频错误选项',
                '让学生分享解题思路和易错点',
            ),
        ),
        (
            '数据跟踪',
            '建立学生学习档案',
            (
                '记录每次测验的知识点掌握情况',
                '跟踪学生进步轨迹',
                '定期评估教学效果',
            ),
        ),
    )

    # Fresh dicts and lists on every call, so callers may mutate the result.
    return [
        {'category': category, 'description': description, 'actions': list(actions)}
        for category, description, actions in catalogue
    ]


def generate_next_steps(class_analysis, error_analysis):
    """Derive follow-up actions from the class analysis.

    Args:
        class_analysis: Object exposing ``average_score``,
            ``need_help_students``, ``total_students`` and
            ``knowledge_point_analysis`` (items with ``accuracy_rate``).
        error_analysis: Not read by this function.

    Returns:
        A list of dicts with ``priority``, ``action`` and ``timeline``, in
        the order the rules below are checked.
    """
    def _step(priority, action, timeline):
        return {'priority': priority, 'action': action, 'timeline': timeline}

    next_steps = []

    # Rule 1: the class average is below the pass mark.
    if class_analysis.average_score < 60:
        next_steps.append(_step('高', '降低教学难度，重新梳理基础知识', '立即执行'))

    # Rule 2: more than 30% of the students need help.
    struggling = len(class_analysis.need_help_students)
    if struggling > class_analysis.total_students * 0.3:
        next_steps.append(_step('高', '安排集体补习，加强基础训练', '本周内'))

    # Rule 3: at least one knowledge point is below 50% accuracy.
    weak_count = sum(
        1 for kp in class_analysis.knowledge_point_analysis if kp.accuracy_rate < 50
    )
    if weak_count:
        next_steps.append(_step('中', f'重点讲解{weak_count}个薄弱知识点', '下周安排'))

    return next_steps


def is_answer_correct(detail, question_map):
    """Tell whether an answer detail matches its question's correct answer.

    Only the first non-blank character of each answer is compared, ignoring
    case.

    Args:
        detail: Object exposing ``question_id`` and ``selected_answer``.
        question_map: Mapping from question id to an object exposing
            ``answer``.

    Returns:
        True when both answers are present and their first characters match;
        False when the question is unknown or either answer is missing,
        empty or whitespace-only.
    """
    question = question_map.get(detail.question_id)
    if not question or not detail.selected_answer or not question.answer:
        return False

    selected = detail.selected_answer.strip()
    expected = question.answer.strip()
    if not selected or not expected:
        # A whitespace-only value is no answer at all; without this guard two
        # blank strings would compare equal and be graded as correct.
        return False

    return selected[0].upper() == expected[0].upper()

# 大语言模型配置相关操作
def get_llm_configs(db: Session) -> List[database.LLMConfig]:
    """Return every LLM configuration, newest ``created_time`` first."""
    query = db.query(database.LLMConfig)
    newest_first = query.order_by(database.LLMConfig.created_time.desc())
    return newest_first.all()

def get_active_llm_config(db: Session) -> database.LLMConfig:
    """Return the first LLM configuration with ``is_active == 1``.

    The result of ``.first()`` is returned as is, so it is ``None`` when no
    configuration is active.
    """
    active_only = db.query(database.LLMConfig).filter(
        database.LLMConfig.is_active == 1
    )
    return active_only.first()

def get_llm_config(db: Session, config_id: int) -> database.LLMConfig:
    """Return the LLM configuration with the given id.

    The result of ``.first()`` is returned as is, so it is ``None`` when the
    id does not exist.
    """
    by_id = db.query(database.LLMConfig).filter(
        database.LLMConfig.id == config_id
    )
    return by_id.first()

def create_llm_config(db: Session, config: schemas.LLMConfigCreate) -> database.LLMConfig:
    """Persist a new LLM configuration.

    When the new configuration is marked active (``is_active == 1``), every
    existing configuration is deactivated first; both changes are committed
    together.

    Returns:
        The stored configuration, refreshed from the database.
    """
    wants_active = config.is_active == 1
    if wants_active:
        # Only one configuration may be active at a time.
        db.query(database.LLMConfig).update({"is_active": 0})

    column_values = {
        "config_name": config.config_name,
        "api_key": config.api_key,
        "base_url": config.base_url,
        "model_name": config.model_name,
        "is_active": config.is_active,
    }
    db_config = database.LLMConfig(**column_values)

    db.add(db_config)
    db.commit()
    db.refresh(db_config)
    return db_config

def update_llm_config(db: Session, config_id: int, config: schemas.LLMConfigUpdate) -> database.LLMConfig:
    """Apply a partial update to an LLM configuration.

    Only the fields explicitly set on ``config`` are written. When the update
    activates this configuration, all other configurations are deactivated
    first.

    Returns:
        The refreshed configuration, or ``None`` when ``config_id`` does not
        exist.
    """
    llm_configs = db.query(database.LLMConfig)
    db_config = llm_configs.filter(database.LLMConfig.id == config_id).first()
    if not db_config:
        return None

    if config.is_active == 1:
        # Only one configuration may be active at a time.
        others = db.query(database.LLMConfig).filter(database.LLMConfig.id != config_id)
        others.update({"is_active": 0})

    # Copy over only the fields the caller actually provided.
    for field, value in config.dict(exclude_unset=True).items():
        setattr(db_config, field, value)
    db_config.updated_time = datetime.now()

    db.commit()
    db.refresh(db_config)
    return db_config

def delete_llm_config(db: Session, config_id: int) -> bool:
    """Delete an LLM configuration.

    Returns:
        True when the configuration existed and was deleted, False when no
        configuration has the given id.
    """
    db_config = db.query(database.LLMConfig).filter(
        database.LLMConfig.id == config_id
    ).first()
    if not db_config:
        return False

    db.delete(db_config)
    db.commit()
    return True

def activate_llm_config(db: Session, config_id: int) -> bool:
    """Make one LLM configuration the only active one.

    The target is looked up before anything is modified, so an unknown
    ``config_id`` leaves the currently active configuration untouched and no
    pending UPDATE behind in the session.

    Returns:
        True when the configuration was activated, False when ``config_id``
        does not exist.
    """
    db_config = db.query(database.LLMConfig).filter(database.LLMConfig.id == config_id).first()
    if not db_config:
        return False

    # Deactivate every other configuration, then activate the target; both
    # changes are committed together.
    db.query(database.LLMConfig).filter(database.LLMConfig.id != config_id).update({"is_active": 0})

    db_config.is_active = 1
    db_config.updated_time = datetime.now()
    db.commit()
    return True

# 题库管理相关CRUD操作
def get_question_banks(db: Session, skip: int = 0, limit: int = 100):
    """Return one page of question banks.

    Args:
        db: Database session used for the query.
        skip: Number of rows to skip (query offset).
        limit: Maximum number of rows to return.
    """
    page = db.query(database.QuestionBank).offset(skip).limit(limit)
    return page.all()

def get_question_bank(db: Session, bank_id: int):
    """Return the question bank with the given id (the ``.first()`` result)."""
    by_id = db.query(database.QuestionBank).filter(
        database.QuestionBank.id == bank_id
    )
    return by_id.first()

def create_question_bank(db: Session, bank: schemas.QuestionBankCreate):
    """Create an empty question bank.

    The bank takes its name, description and creator from ``bank``, gets the
    current local time as creation and update time, and starts with a
    question count of 0.

    Returns:
        The stored bank, refreshed from the database.
    """
    from datetime import datetime

    column_values = {
        "name": bank.name,
        "description": bank.description,
        "created_by": bank.created_by,
        "created_time": datetime.now(),
        "updated_time": datetime.now(),
        "question_count": 0,
    }
    db_bank = database.QuestionBank(**column_values)

    db.add(db_bank)
    db.commit()
    db.refresh(db_bank)
    return db_bank

def update_question_bank(db: Session, bank_id: int, bank: schemas.QuestionBankUpdate):
    """Update the name and/or description of a question bank.

    Fields that are ``None`` on ``bank`` are left unchanged; the update time
    is always refreshed.

    Returns:
        The refreshed bank, or ``None`` when ``bank_id`` does not exist.
    """
    from datetime import datetime

    db_bank = db.query(database.QuestionBank).filter(
        database.QuestionBank.id == bank_id
    ).first()
    if not db_bank:
        return None

    # Only overwrite the columns for which a value was supplied.
    for column in ("name", "description"):
        new_value = getattr(bank, column)
        if new_value is not None:
            setattr(db_bank, column, new_value)
    db_bank.updated_time = datetime.now()

    db.commit()
    db.refresh(db_bank)
    return db_bank

def delete_question_bank(db: Session, bank_id: int):
    """Delete a question bank.

    Returns:
        True when the bank existed and was deleted, False otherwise.
    """
    db_bank = db.query(database.QuestionBank).filter(
        database.QuestionBank.id == bank_id
    ).first()

    if db_bank:
        # NOTE(review): the original comment states that the bank's questions
        # are removed by a foreign-key cascade in the database - confirm.
        db.delete(db_bank)
        db.commit()
        return True

    return False

def get_question_bank_with_questions(db: Session, bank_id: int):
    """Return a question bank together with its questions.

    Returns:
        ``{"bank": <bank>, "questions": [<question>, ...]}``, or ``None``
        when ``bank_id`` does not exist.
    """
    bank = db.query(database.QuestionBank).filter(
        database.QuestionBank.id == bank_id
    ).first()
    if not bank:
        return None

    result = {"bank": bank}
    result["questions"] = db.query(database.Question).filter(
        database.Question.question_bank_id == bank_id
    ).all()
    return result

def update_question_bank_count(db: Session, bank_id: int):
    """Recount the questions of a bank and store the result.

    The number of questions whose ``question_bank_id`` equals ``bank_id`` is
    written to the bank's ``question_count`` column and committed.

    Returns:
        The number of questions counted.
    """
    questions_in_bank = db.query(database.Question).filter(
        database.Question.question_bank_id == bank_id
    )
    count = questions_in_bank.count()

    target_bank = db.query(database.QuestionBank).filter(
        database.QuestionBank.id == bank_id
    )
    target_bank.update({"question_count": count})
    db.commit()
    return count

def move_single_question(db: Session, question_id: int, bank_id: int = None):
    """Move one question into another bank (or out of any bank).

    After the move is committed, the question counts of the previous bank
    and of the target bank are refreshed.

    Args:
        db: Database session used for the queries.
        question_id: Question to move.
        bank_id: Target bank id; ``None`` detaches the question.

    Returns:
        True on success; False when the question does not exist or an error
        occurred (the session is rolled back and the error is logged).
    """
    try:
        question = db.query(database.Question).filter(
            database.Question.id == question_id
        ).first()
        if not question:
            return False

        # Remember where the question came from so that count can be fixed.
        previous_bank_id = question.question_bank_id

        question.question_bank_id = bank_id
        question.updated_time = datetime.now()
        db.commit()

        # Refresh the counters of the source bank, then of the target bank.
        for affected_id in (previous_bank_id, bank_id):
            if affected_id:
                update_question_bank_count(db, affected_id)

        return True
    except Exception as e:
        db.rollback()
        logger.error(f"移动题目失败: {str(e)}")
        return False

def copy_question_to_bank(db: Session, question_id: int, bank_id: int):
    """Copy one question into a bank.

    A new question row is created with the content fields of the original,
    ``bank_id`` as its bank and the current local time as creation and update
    time. The target bank's question count is refreshed afterwards.

    Returns:
        The new question, or ``None`` when the original does not exist or an
        error occurred (the session is rolled back and the error is logged).
    """
    # Columns whose values are carried over unchanged from the original.
    copied_fields = (
        "question_content",
        "option_a",
        "option_b",
        "option_c",
        "option_d",
        "answer",
        "knowledge_point",
        "explanation",
        "wrong_analysis",
        "difficulty_level",
        "topic_category",
    )

    try:
        original_question = db.query(database.Question).filter(
            database.Question.id == question_id
        ).first()
        if not original_question:
            return None

        carried_over = {name: getattr(original_question, name) for name in copied_fields}
        new_question = database.Question(
            **carried_over,
            question_bank_id=bank_id,
            created_time=datetime.now(),
            updated_time=datetime.now()
        )

        db.add(new_question)
        db.commit()
        db.refresh(new_question)

        # Keep the target bank's counter in sync with its contents.
        update_question_bank_count(db, bank_id)

        logger.info(f"成功复制题目 {question_id} 到题库 {bank_id}，新题目ID: {new_question.id}")
        return new_question
    except Exception as e:
        db.rollback()
        logger.error(f"复制题目失败: {str(e)}")
        return None

def copy_questions_to_bank(db: Session, question_ids: List[int], bank_id: int):
    """Copy several questions into a bank in a single commit.

    Each id is handled independently: a missing original or an error while
    preparing the copy counts as a failure and the loop continues. All
    prepared copies are committed together, then the target bank's question
    count is refreshed.

    Returns:
        ``(success_count, failed_count)``. When the commit or the recount
        fails, the session is rolled back and ``(0, len(question_ids))`` is
        returned.
    """
    # Columns whose values are carried over unchanged from the original.
    copied_fields = (
        "question_content",
        "option_a",
        "option_b",
        "option_c",
        "option_d",
        "answer",
        "knowledge_point",
        "explanation",
        "wrong_analysis",
        "difficulty_level",
        "topic_category",
    )
    success_count = 0
    failed_count = 0

    try:
        for question_id in question_ids:
            try:
                source = db.query(database.Question).filter(
                    database.Question.id == question_id
                ).first()
                if not source:
                    failed_count += 1
                    continue

                carried_over = {name: getattr(source, name) for name in copied_fields}
                duplicate = database.Question(
                    **carried_over,
                    question_bank_id=bank_id,
                    created_time=datetime.now(),
                    updated_time=datetime.now()
                )
                db.add(duplicate)
                success_count += 1

            except Exception as e:
                logger.error(f"复制题目 {question_id} 失败: {str(e)}")
                failed_count += 1

        db.commit()

        # Keep the target bank's counter in sync with its contents.
        update_question_bank_count(db, bank_id)

        logger.info(f"批量复制完成：成功 {success_count} 道，失败 {failed_count} 道")
        return success_count, failed_count

    except Exception as e:
        db.rollback()
        logger.error(f"批量复制题目失败: {str(e)}")
        return 0, len(question_ids)

def move_questions_to_bank(db: Session, question_ids: List[int], bank_id: int):
    """Move several questions into one bank.

    Only the banks the questions came from and the target bank have their
    question counts refreshed. Moved questions get a new ``updated_time``,
    matching what ``move_single_question`` does for a single question.

    Args:
        db: Database session used for the queries.
        question_ids: Ids of the questions to move.
        bank_id: Target bank id.

    Returns:
        True when the move was committed, False when the target bank does not
        exist.

    Raises:
        Exception: Any database error is re-raised after the session has been
            rolled back.
    """
    bank = db.query(database.QuestionBank).filter(database.QuestionBank.id == bank_id).first()
    if not bank:
        return False

    try:
        # Collect the source banks before the update overwrites that
        # information; questions without a bank contribute nothing.
        source_rows = db.query(database.Question.question_bank_id).filter(
            database.Question.id.in_(question_ids)
        ).distinct().all()
        affected_bank_ids = {row[0] for row in source_rows if row[0] is not None}
        affected_bank_ids.add(bank_id)

        db.query(database.Question).filter(database.Question.id.in_(question_ids)).update(
            {"question_bank_id": bank_id, "updated_time": datetime.now()},
            synchronize_session=False
        )
        db.commit()

        # Recount only the banks whose contents changed.
        for affected_bank_id in affected_bank_ids:
            update_question_bank_count(db, affected_bank_id)
    except Exception:
        db.rollback()
        raise

    return True

def get_active_question_bank(db: Session):
    """Return the first question bank with ``is_active == 1``.

    The ``.first()`` result is returned as is, so it is ``None`` when no bank
    is active.
    """
    active_only = db.query(database.QuestionBank).filter(
        database.QuestionBank.is_active == 1
    )
    return active_only.first()

def activate_question_bank(db: Session, bank_id: int):
    """Make one question bank the only active one.

    All banks are deactivated, the target is activated, the change is
    committed and the questions cache is cleared.

    Returns:
        True on success; False when the bank does not exist or an error
        occurred (the session is rolled back and the error is logged).
    """
    try:
        target = db.query(database.QuestionBank).filter(
            database.QuestionBank.id == bank_id
        ).first()
        if not target:
            return False

        # Step 1: switch every bank off.
        db.query(database.QuestionBank).update(
            {"is_active": 0}, synchronize_session=False
        )

        # Step 2: switch the requested bank on.
        only_target = db.query(database.QuestionBank).filter(
            database.QuestionBank.id == bank_id
        )
        only_target.update({"is_active": 1}, synchronize_session=False)

        db.commit()

        # Clear the questions cache after switching the active bank.
        CacheManager.clear_questions_cache()

        return True
    except Exception as e:
        db.rollback()
        logger.error(f"激活题库失败: {str(e)}")
        return False

def get_questions_by_bank(db: Session, bank_id: int) -> List[database.Question]:
    """Return every question whose ``question_bank_id`` equals ``bank_id``."""
    in_bank = db.query(database.Question).filter(
        database.Question.question_bank_id == bank_id
    )
    return in_bank.all()

def get_questions_from_active_bank(db: Session) -> List[database.Question]:
    """Return every question of the currently active question bank.

    Returns:
        The questions of the active bank, or an empty list when no bank is
        active.
    """
    active_bank = get_active_question_bank(db)

    if active_bank:
        in_active_bank = db.query(database.Question).filter(
            database.Question.question_bank_id == active_bank.id
        )
        return in_active_bank.all()

    # Without an active bank there is nothing to serve.
    return []

def export_submissions_to_pdf(db: Session, class_id: int):
    """Load the submissions of a class for a PDF export.

    This function only queries and returns the submission rows of
    ``class_id``; it does not produce a PDF itself.
    """
    of_class = db.query(database.Submission).filter(
        database.Submission.class_id == class_id
    )
    return of_class.all()

# 班级配置相关CRUD函数
def get_class_config(db: Session):
    """Return the class configuration, creating a default one if needed.

    When no configuration row exists, a default (``current_class_id=1``,
    ``class_name="1班"``) is inserted, committed and returned.
    """
    existing = db.query(database.ClassConfig).first()
    if existing:
        return existing

    # First use: persist the default configuration.
    default_config = database.ClassConfig(current_class_id=1, class_name="1班")
    db.add(default_config)
    db.commit()
    db.refresh(default_config)
    return default_config

def delete_class_submissions(db: Session, class_id: int) -> bool:
    """Delete every submission of a class together with its answer details.

    Answer details are removed first, then the submissions themselves; both
    happen in one transaction.

    Returns:
        True when the deletion was committed, False when an error occurred
        (the session is rolled back and the error is logged).
    """
    try:
        # Delete the answer details before the submissions they belong to.
        submission_ids = db.query(database.Submission.id).filter(database.Submission.class_id == class_id).all()
        for (submission_id,) in submission_ids:
            db.query(database.SubmissionDetail).filter(database.SubmissionDetail.submission_id == submission_id).delete()

        db.query(database.Submission).filter(database.Submission.class_id == class_id).delete()
        db.commit()
        return True
    except Exception as e:
        db.rollback()
        # Log the failure so it can be diagnosed, as the other CRUD helpers do.
        logger.error(f"删除班级提交记录失败 (班级ID: {class_id}): {str(e)}")
        return False

def update_class_config(db: Session, current_class_id: int, class_name: str = None):
    """Update the class configuration, creating it when none exists.

    For an existing configuration the current class id is always replaced,
    the class name only when a non-empty ``class_name`` is given, and the
    update time is refreshed.

    Returns:
        The stored configuration, refreshed from the database.
    """
    config = db.query(database.ClassConfig).first()

    if config:
        config.current_class_id = current_class_id
        if class_name:
            config.class_name = class_name
        config.updated_time = datetime.now()
    else:
        # No configuration yet: create one from the given values.
        config = database.ClassConfig(current_class_id=current_class_id, class_name=class_name)
        db.add(config)

    db.commit()
    db.refresh(config)
    return config

# 学生做题记录查看功能
def get_student_submission_records(db: Session, student_name: str = None, class_id: int = None, 
                                 start_date: str = None, end_date: str = None, 
                                 skip: int = 0, limit: int = 100) -> List[Dict[str, Any]]:
    """Return one page of submission records, newest submission time first.

    Every filter is optional and skipped when its argument is falsy.

    Args:
        db: Database session used for the query.
        student_name: Substring to match inside the student name (``LIKE``).
        class_id: Class the submissions must belong to.
        start_date: Lower bound (inclusive) for ``submission_date``.
        end_date: Upper bound (inclusive) for ``submission_date``.
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        A list of plain dicts, one per submission.

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        # Collect the active filter conditions, then apply them in order.
        conditions = []
        if student_name:
            conditions.append(database.Submission.student_name.like(f"%{student_name}%"))
        if class_id:
            conditions.append(database.Submission.class_id == class_id)
        if start_date:
            conditions.append(database.Submission.submission_date >= start_date)
        if end_date:
            conditions.append(database.Submission.submission_date <= end_date)

        query = db.query(database.Submission)
        for condition in conditions:
            query = query.filter(condition)

        newest_first = query.order_by(database.Submission.submission_time.desc())
        submissions = newest_first.offset(skip).limit(limit).all()

        return [
            {
                "id": submission.id,
                "student_name": submission.student_name,
                "class_id": submission.class_id,
                "score": submission.score,
                "submission_date": submission.submission_date,
                "submission_time": submission.submission_time,
                "total_questions": submission.total_questions,
                "correct_answers": submission.correct_answers,
                "accuracy_rate": submission.accuracy_rate,
                "client_ip": submission.client_ip
            }
            for submission in submissions
        ]
    except Exception as e:
        logger.error(f"获取学生做题记录失败: {str(e)}")
        raise

def get_student_submission_detail(db: Session, submission_id: int) -> Optional[Dict[str, Any]]:
    """Return one submission together with its per-question answers.

    The questions referenced by the answer details are loaded with a single
    query instead of one query per answer.

    Args:
        db: Database session used for the queries.
        submission_id: Id of the submission to load.

    Returns:
        ``{"submission": {...}, "question_details": [...]}``, or ``None``
        when the submission does not exist. Answers whose question cannot be
        found are left out of ``question_details``.

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        submission = db.query(database.Submission).filter(database.Submission.id == submission_id).first()
        if not submission:
            return None

        details = db.query(database.SubmissionDetail).filter(
            database.SubmissionDetail.submission_id == submission_id
        ).all()

        # Fetch all referenced questions at once and index them by id.
        question_ids = list({detail.question_id for detail in details if detail.question_id is not None})
        questions_by_id = {}
        if question_ids:
            questions_by_id = {
                question.id: question
                for question in db.query(database.Question).filter(
                    database.Question.id.in_(question_ids)
                ).all()
            }

        # Keep the order of the answer details.
        question_details = []
        for detail in details:
            question = questions_by_id.get(detail.question_id)
            if question:
                question_details.append({
                    "question_id": question.id,
                    "question_content": question.question_content,
                    "option_a": question.option_a,
                    "option_b": question.option_b,
                    "option_c": question.option_c,
                    "option_d": question.option_d,
                    "correct_answer": question.answer,
                    "selected_answer": detail.selected_answer,
                    "is_correct": detail.is_correct,
                    "knowledge_point": question.knowledge_point,
                    "explanation": question.explanation,
                    "answer_time": detail.answer_time
                })

        return {
            "submission": {
                "id": submission.id,
                "student_name": submission.student_name,
                "class_id": submission.class_id,
                "score": submission.score,
                "submission_date": submission.submission_date,
                "submission_time": submission.submission_time,
                "total_questions": submission.total_questions,
                "correct_answers": submission.correct_answers,
                "accuracy_rate": submission.accuracy_rate,
                "client_ip": submission.client_ip
            },
            "question_details": question_details
        }
    except Exception as e:
        logger.error(f"获取学生做题详细记录失败: {str(e)}")
        raise

def get_student_statistics(db: Session, student_name: str, class_id: int = None) -> Dict[str, Any]:
    """Aggregate the submission statistics of one student.

    Args:
        db: Database session used for the query.
        student_name: Student name, matched exactly.
        class_id: Optional class filter; skipped when falsy.

    Returns:
        A dict with the number of submissions, average/highest/lowest score,
        total questions answered, total correct answers, overall accuracy
        (percent, two decimals) and the five most recent submissions. All
        numbers are 0 and the list is empty when the student has no
        submissions.

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        query = db.query(database.Submission).filter(
            database.Submission.student_name == student_name
        )
        if class_id:
            query = query.filter(database.Submission.class_id == class_id)
        submissions = query.all()

        if not submissions:
            return {
                "student_name": student_name,
                "total_submissions": 0,
                "average_score": 0,
                "highest_score": 0,
                "lowest_score": 0,
                "total_questions_answered": 0,
                "total_correct_answers": 0,
                "overall_accuracy_rate": 0,
                "recent_submissions": []
            }

        total_submissions = len(submissions)
        scores = [item.score for item in submissions]
        total_questions = sum(item.total_questions for item in submissions)
        total_correct = sum(item.correct_answers for item in submissions)

        # Overall accuracy as a percentage; 0 when no question was answered.
        if total_questions > 0:
            overall_accuracy = total_correct / total_questions * 100
        else:
            overall_accuracy = 0

        # The five latest submissions by submission time.
        latest_first = sorted(submissions, key=lambda item: item.submission_time, reverse=True)
        recent_data = [
            {
                "id": item.id,
                "score": item.score,
                "submission_date": item.submission_date,
                "submission_time": item.submission_time,
                "accuracy_rate": item.accuracy_rate
            }
            for item in latest_first[:5]
        ]

        return {
            "student_name": student_name,
            "total_submissions": total_submissions,
            "average_score": round(sum(scores) / total_submissions, 2),
            "highest_score": max(scores),
            "lowest_score": min(scores),
            "total_questions_answered": total_questions,
            "total_correct_answers": total_correct,
            "overall_accuracy_rate": round(overall_accuracy, 2),
            "recent_submissions": recent_data
        }
    except Exception as e:
        logger.error(f"获取学生统计信息失败: {str(e)}")
        raise